{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "0",
   "metadata": {},
   "source": [
    "# Crystal Toolkit Relaxation Viewer\n",
    "\n",
    "This notebook shows how to visualize a CHGNet relaxation trajectory in a Plotly Dash app using Crystal Toolkit.\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "1",
   "metadata": {},
   "source": [
    "Running the last cell in this notebook should spin up a `dash` app that looks like this:\n",
    "\n",
    "![Crystaltoolkit Relaxation Viewer Screenshot](https://user-images.githubusercontent.com/30958850/230510639-2e659c9b-3a99-438b-9668-628299171602.png)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2",
   "metadata": {},
   "outputs": [],
   "source": [
    "try:\n",
    "    import chgnet  # noqa: F401\n",
    "except ImportError:\n",
    "    # CHGNet is not installed (e.g. on Google Colab): clone the repo and install it\n",
    "    # with the `examples` extra, i.e. the extra dependencies needed to run the dash\n",
    "    # app in this notebook (https://github.com/materialsproject/crystaltoolkit)\n",
    "    !git clone --depth 1 https://github.com/CederGroupHub/chgnet\n",
    "    # %pip (unlike !pip) always installs into the environment of the running kernel\n",
    "    %pip install './chgnet[examples]'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from pymatgen.core import Structure"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4",
   "metadata": {},
   "outputs": [],
   "source": [
    "try:\n",
    "    from chgnet import ROOT\n",
    "\n",
    "    structure = Structure.from_file(f\"{ROOT}/examples/mp-18767-LiMnO2.cif\")\n",
    "except Exception:\n",
    "    # importing chgnet or reading the local file failed -> download the CIF instead\n",
    "    from urllib.request import urlopen\n",
    "\n",
    "    url = \"https://github.com/CederGroupHub/chgnet/raw/-/examples/mp-18767-LiMnO2.cif\"\n",
    "    # the context manager closes the HTTP response even if reading or decoding fails\n",
    "    with urlopen(url) as response:\n",
    "        cif = response.read().decode(\"utf-8\")\n",
    "    structure = Structure.from_str(cif, fmt=\"cif\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "original: ('Pmmn', 59)\n",
      "perturbed: ('P1', 1)\n"
     ]
    }
   ],
   "source": [
    "print(f\"original: {structure.get_space_group_info()}\")\n",
    "\n",
    "# seeded random generator so the perturbation is reproducible across runs\n",
    "rng = np.random.default_rng(seed=0)\n",
    "\n",
    "# perturb all atom positions by a small amount\n",
    "for site in structure:\n",
    "    site.coords += rng.normal(size=3) * 0.3\n",
    "\n",
    "# stretch the cell by a small amount\n",
    "structure.scale_lattice(structure.volume * 1.1)\n",
    "\n",
    "print(f\"perturbed: {structure.get_space_group_info()}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CHGNet initialized with 400,438 parameters\n",
      "CHGNet will run on mps\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/janosh/dev/chgnet/chgnet/model/composition_model.py:177: UserWarning: MPS: no support for int64 min/max ops, casting it to int32 (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/native/mps/operations/ReduceOps.mm:1271.)\n",
      "  composition_fea = torch.bincount(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "      Step     Time          Energy         fmax\n",
      "*Force-consistent energies used in optimization.\n",
      "FIRE:    0 17:10:46      -39.349243*      93.7997\n",
      "FIRE:    1 17:10:46      -53.316616*      10.1811\n",
      "FIRE:    2 17:10:46      -53.377773*      15.6013\n",
      "FIRE:    3 17:10:47      -54.071163*      11.7506\n",
      "FIRE:    4 17:10:47      -54.818359*       5.3271\n",
      "FIRE:    5 17:10:48      -55.177044*       7.9573\n",
      "FIRE:    6 17:10:48      -55.661133*       9.9137\n",
      "FIRE:    7 17:10:49      -56.486736*       6.3212\n",
      "FIRE:    8 17:10:49      -57.129395*       4.4612\n",
      "FIRE:    9 17:10:50      -57.536762*       6.1146\n",
      "FIRE:   10 17:10:50      -57.886269*       2.9151\n",
      "FIRE:   11 17:10:51      -57.534672*       9.1977\n",
      "FIRE:   12 17:10:52      -57.731918*       6.6787\n",
      "FIRE:   13 17:10:52      -57.953892*       3.5996\n",
      "FIRE:   14 17:10:53      -58.058907*       2.4661\n",
      "FIRE:   15 17:10:53      -58.089428*       4.2591\n",
      "FIRE:   16 17:10:53      -58.099186*       4.0957\n",
      "FIRE:   17 17:10:54      -58.117531*       3.7736\n",
      "FIRE:   18 17:10:54      -58.142269*       3.3059\n",
      "FIRE:   19 17:10:54      -58.170528*       2.7055\n",
      "FIRE:   20 17:10:55      -58.199223*       1.9944\n",
      "FIRE:   21 17:10:55      -58.225574*       1.6872\n",
      "FIRE:   22 17:10:56      -58.247837*       1.6420\n",
      "FIRE:   23 17:10:56      -58.267609*       1.5754\n",
      "FIRE:   24 17:10:57      -58.286213*       1.5527\n",
      "FIRE:   25 17:10:58      -58.307659*       1.9544\n",
      "FIRE:   26 17:10:58      -58.336651*       2.1167\n",
      "FIRE:   27 17:10:58      -58.374733*       1.7955\n",
      "FIRE:   28 17:10:59      -58.417065*       1.2746\n",
      "FIRE:   29 17:10:59      -58.454224*       1.0842\n",
      "FIRE:   30 17:11:00      -58.482494*       1.4909\n",
      "FIRE:   31 17:11:00      -58.511620*       1.7936\n",
      "FIRE:   32 17:11:00      -58.551266*       1.3864\n",
      "FIRE:   33 17:11:00      -58.593399*       0.8655\n",
      "FIRE:   34 17:11:01      -58.626717*       1.1029\n",
      "FIRE:   35 17:11:01      -58.667667*       1.1856\n",
      "FIRE:   36 17:11:02      -58.714378*       0.7611\n",
      "FIRE:   37 17:11:02      -58.751740*       1.4115\n",
      "FIRE:   38 17:11:02      -58.798595*       0.6277\n",
      "FIRE:   39 17:11:03      -58.825634*       1.6683\n",
      "FIRE:   40 17:11:04      -58.860550*       0.6463\n",
      "FIRE:   41 17:11:04      -58.879448*       1.6940\n",
      "FIRE:   42 17:11:05      -58.889172*       1.0395\n",
      "FIRE:   43 17:11:05      -58.897785*       0.5511\n",
      "FIRE:   44 17:11:06      -58.900936*       0.9203\n",
      "FIRE:   45 17:11:06      -58.902317*       0.8079\n",
      "FIRE:   46 17:11:06      -58.904602*       0.5999\n",
      "FIRE:   47 17:11:06      -58.907150*       0.4713\n",
      "FIRE:   48 17:11:07      -58.909378*       0.4035\n",
      "FIRE:   49 17:11:07      -58.911129*       0.3782\n",
      "FIRE:   50 17:11:07      -58.912685*       0.4900\n",
      "FIRE:   51 17:11:08      -58.914524*       0.5925\n",
      "FIRE:   52 17:11:08      -58.917168*       0.5787\n",
      "FIRE:   53 17:11:09      -58.920650*       0.4262\n",
      "FIRE:   54 17:11:09      -58.924351*       0.2559\n",
      "FIRE:   55 17:11:10      -58.927425*       0.2542\n",
      "FIRE:   56 17:11:10      -58.930111*       0.4618\n",
      "FIRE:   57 17:11:10      -58.933582*       0.4244\n",
      "FIRE:   58 17:11:11      -58.937565*       0.2129\n",
      "FIRE:   59 17:11:11      -58.940548*       0.3162\n",
      "FIRE:   60 17:11:11      -58.943970*       0.3788\n",
      "FIRE:   61 17:11:11      -58.948582*       0.1709\n",
      "FIRE:   62 17:11:12      -58.952435*       0.4052\n",
      "FIRE:   63 17:11:12      -58.957687*       0.1760\n",
      "FIRE:   64 17:11:13      -58.961376*       0.4583\n",
      "FIRE:   65 17:11:13      -58.965767*       0.1113\n",
      "FIRE:   66 17:11:13      -58.967800*       0.4162\n",
      "FIRE:   67 17:11:13      -58.969425*       0.5236\n",
      "FIRE:   68 17:11:14      -58.970615*       0.2232\n",
      "FIRE:   69 17:11:14      -58.970974*       0.2027\n",
      "FIRE:   70 17:11:14      -58.971069*       0.1704\n",
      "FIRE:   71 17:11:14      -58.971218*       0.1290\n",
      "FIRE:   72 17:11:14      -58.971375*       0.1228\n",
      "FIRE:   73 17:11:15      -58.971523*       0.1153\n",
      "FIRE:   74 17:11:15      -58.971649*       0.1073\n",
      "FIRE:   75 17:11:15      -58.971767*       0.1409\n",
      "FIRE:   76 17:11:15      -58.971931*       0.1412\n",
      "FIRE:   77 17:11:15      -58.972164*       0.1046\n",
      "FIRE:   78 17:11:15      -58.972416*       0.0777\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "\n",
    "from chgnet.model import StructOptimizer\n",
    "\n",
    "# relax the perturbed structure with CHGNet and keep the recorded trajectory\n",
    "relax_result = StructOptimizer().relax(structure)\n",
    "trajectory = relax_result[\"trajectory\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7",
   "metadata": {},
   "outputs": [],
   "source": [
    "e_col = \"Energy (eV)\"\n",
    "force_col = \"Force (eV/Å)\"\n",
    "\n",
    "# one value per relaxation step: mean over atoms of the norm of the force on each atom\n",
    "mean_force_norms = [\n",
    "    np.linalg.norm(step_forces, axis=1).mean() for step_forces in trajectory.forces\n",
    "]\n",
    "df_traj = (\n",
    "    pd.DataFrame(trajectory.energies, columns=[e_col])\n",
    "    .assign(**{force_col: mean_force_norms})\n",
    "    .rename_axis(\"step\")\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dft_energy=-59.09 eV (see https://materialsproject.org/materials/mp-18767)\n"
     ]
    }
   ],
   "source": [
    "# Materials Project ID, used to build the links to materialsproject.org\n",
    "mp_id = \"mp-18767\"\n",
    "\n",
    "# reference energy in eV, drawn as a horizontal line in the energy/force plot below\n",
    "dft_energy = -59.09\n",
    "\n",
    "print(f\"{dft_energy=:.2f} eV (see https://materialsproject.org/materials/{mp_id})\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "        <iframe\n",
       "            width=\"100%\"\n",
       "            height=\"650\"\n",
       "            src=\"http://127.0.0.1:8050/\"\n",
       "            frameborder=\"0\"\n",
       "            allowfullscreen\n",
       "            \n",
       "        ></iframe>\n",
       "        "
      ],
      "text/plain": [
       "<IPython.lib.display.IFrame at 0x3d5938e10>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import crystal_toolkit.components as ctc\n",
    "import plotly.graph_objects as go\n",
    "from crystal_toolkit.settings import SETTINGS\n",
    "from dash import Dash, dcc, html\n",
    "from dash.dependencies import Input, Output\n",
    "from pymatgen.core import Structure\n",
    "\n",
    "app = Dash(prevent_initial_callbacks=True, assets_folder=SETTINGS.ASSETS_PATH)\n",
    "\n",
    "# coarsen the slider for long trajectories: with floor division the step stays 1 for\n",
    "# fewer than 40 steps and beyond that the slider ends up with roughly 20-40 stops\n",
    "n_steps = len(trajectory)\n",
    "step_size = max(1, n_steps // 20)\n",
    "slider = dcc.Slider(\n",
    "    id=\"slider\", min=0, max=n_steps - 1, step=step_size, updatemode=\"drag\"\n",
    ")\n",
    "\n",
    "\n",
    "def plot_energy_and_forces(\n",
    "    df: pd.DataFrame, step: int, e_col: str, force_col: str, title: str\n",
    ") -> go.Figure:\n",
    "    \"\"\"Plot energy and forces as a function of relaxation step.\n",
    "\n",
    "    Args:\n",
    "        df (pd.DataFrame): Data to plot. Its index is used as x-axis and it must\n",
    "            contain the columns e_col and force_col.\n",
    "        step (int): x-position of the dashed vertical marker line.\n",
    "        e_col (str): Energy column, also used as title of the primary y-axis.\n",
    "        force_col (str): Force column, also used as title of the secondary y-axis.\n",
    "        title (str): Figure title.\n",
    "\n",
    "    Returns:\n",
    "        go.Figure: Energy and force traces plus a dotted horizontal line at the\n",
    "            notebook-level dft_energy.\n",
    "    \"\"\"\n",
    "    fig = go.Figure()\n",
    "    fig.update_layout(\n",
    "        template=\"plotly_white\",\n",
    "        title=title,\n",
    "        xaxis=dict(title=\"Relaxation Step\"),\n",
    "        yaxis=dict(title=e_col),\n",
    "        yaxis2=dict(title=force_col, overlaying=\"y\", side=\"right\"),\n",
    "        legend=dict(yanchor=\"top\", y=1, xanchor=\"right\", x=1),\n",
    "    )\n",
    "\n",
    "    # Default trace colors are only assigned when the figure is rendered, so the\n",
    "    # line color of a trace is None unless it was set explicitly. Take the first\n",
    "    # colorway color of the template (the color the first trace gets by default)\n",
    "    # and set it explicitly so the DFT energy line can reuse it.\n",
    "    colorway = fig.layout.template.layout.colorway\n",
    "    line_color = colorway[0] if colorway else \"#636EFA\"\n",
    "\n",
    "    # energy trace = primary y-axis\n",
    "    fig.add_trace(\n",
    "        go.Scatter(\n",
    "            x=df.index,\n",
    "            y=df[e_col],\n",
    "            mode=\"lines\",\n",
    "            name=\"Energy\",\n",
    "            line=dict(color=line_color),\n",
    "        )\n",
    "    )\n",
    "\n",
    "    # forces trace = secondary y-axis\n",
    "    fig.add_trace(\n",
    "        go.Scatter(x=df.index, y=df[force_col], mode=\"lines\", name=\"Forces\", yaxis=\"y2\")\n",
    "    )\n",
    "\n",
    "    # vertical line at the specified step\n",
    "    fig.add_vline(x=step, line=dict(dash=\"dash\", width=1))\n",
    "\n",
    "    # horizontal line for DFT final energy, in the same color as the energy trace\n",
    "    anno = dict(text=\"DFT final energy\", yanchor=\"top\")\n",
    "    fig.add_hline(\n",
    "        y=dft_energy,\n",
    "        line=dict(dash=\"dot\", width=1, color=line_color),\n",
    "        annotation=anno,\n",
    "    )\n",
    "\n",
    "    return fig\n",
    "\n",
    "\n",
    "def make_title(spg_symbol: str, spg_num: int) -> str:\n",
    "    \"\"\"Build the figure title: a link to the Materials Project page of mp_id\n",
    "    followed by the space group symbol and number.\n",
    "    \"\"\"\n",
    "    href = f\"https://materialsproject.org/materials/{mp_id}/\"\n",
    "    # the `{href=}` specifier renders as href='<url>', i.e. as the link's HTML attribute\n",
    "    return f\"<a {href=}>{mp_id}</a> - {spg_symbol} ({spg_num})\"\n",
    "\n",
    "\n",
    "# initial view: title from the space group of the current structure, marker at step 0\n",
    "initial_spg_info = structure.get_space_group_info()\n",
    "title = make_title(*initial_spg_info)\n",
    "\n",
    "graph = dcc.Graph(\n",
    "    id=\"fig\",\n",
    "    figure=plot_energy_and_forces(df_traj, 0, e_col, force_col, title),\n",
    "    style={\"maxWidth\": \"50%\"},\n",
    ")\n",
    "\n",
    "struct_comp = ctc.StructureMoleculeComponent(id=\"structure\", struct_or_mol=structure)\n",
    "\n",
    "heading = html.H1(\n",
    "    \"Structure Relaxation Trajectory\", style={\"margin\": \"1em\", \"fontSize\": \"2em\"}\n",
    ")\n",
    "hint = html.P(\"Drag slider to see structure at different relaxation steps.\")\n",
    "# structure viewer and energy/force plot side by side\n",
    "viewer_row = html.Div(\n",
    "    [struct_comp.layout(), graph], style={\"display\": \"flex\", \"gap\": \"2em\"}\n",
    ")\n",
    "\n",
    "app.layout = html.Div(\n",
    "    [heading, hint, slider, viewer_row],\n",
    "    style={\n",
    "        \"margin\": \"auto\",\n",
    "        \"textAlign\": \"center\",\n",
    "        \"maxWidth\": \"1200px\",\n",
    "        \"padding\": \"2em\",\n",
    "    },\n",
    ")\n",
    "\n",
    "ctc.register_crystal_toolkit(app=app, layout=app.layout)\n",
    "\n",
    "\n",
    "@app.callback(\n",
    "    Output(struct_comp.id(), \"data\"), Output(graph, \"figure\"), Input(slider, \"value\")\n",
    ")\n",
    "def update_structure(step: int) -> tuple[Structure, go.Figure]:\n",
    "    \"\"\"Slider callback: move the displayed structure to the selected relaxation step\n",
    "    and redraw the figure with the dashed vertical line at that step.\n",
    "\n",
    "    Args:\n",
    "        step (int): Slider value, used as index into the trajectory.\n",
    "\n",
    "    Returns:\n",
    "        tuple[Structure, go.Figure]: The updated structure and the redrawn figure.\n",
    "    \"\"\"\n",
    "    new_lattice = trajectory.cells[step]\n",
    "    new_coords = trajectory.atom_positions[step]\n",
    "\n",
    "    # the notebook-level structure is updated in place for efficiency\n",
    "    structure.lattice = new_lattice\n",
    "    assert len(structure) == len(new_coords)\n",
    "    for site, new_coord in zip(structure, new_coords, strict=True):\n",
    "        site.coords = new_coord\n",
    "\n",
    "    # rebuild the title from the space group of the updated structure\n",
    "    spg_symbol, spg_num = structure.get_space_group_info()\n",
    "    fig = plot_energy_and_forces(\n",
    "        df_traj, step, e_col, force_col, make_title(spg_symbol, spg_num)\n",
    "    )\n",
    "\n",
    "    return structure, fig\n",
    "\n",
    "\n",
    "# Dash's `run` takes the height of the inline iframe as `jupyter_height`; a plain\n",
    "# `height` keyword is not one of its parameters, so the iframe kept its default height\n",
    "app.run(jupyter_height=800, use_reloader=False)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py310",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
